Amazon SageMakerのSequence2Sequenceを使って機械翻訳する

Amazon SageMakerのSequence2Sequenceを使って機械翻訳する

Clock Icon2018.08.13

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

こんにちは、小澤です。

当エントリではAmazon SageMakerの組み込みアルゴリズムの1つである「Sequence to Sequence」についての解説を書かせていただきます。

目次

Sequence to Sequenceとは

Sequence to Sequence(以下seq2seq)とは、Deep Learningの手法の1つでRNN(Recurrent Neural Network)ベースのネットワークとなっています。

一般的なフィードフォワード型のニューラルネットワークでは、入力層から出力層に向けて、一方通行のネットワーク構造になります。 それに対して、RNNでは自身の出力が自身の入力となる閉路を持つようなネットワークになります。

このネットワークがどのような動きをするか、というと無限に計算をし続けしまう...というわけではなく、系列データにおいて、1つ前の出力が次の入力として渡されるような動きとなります。 RNNではこの流れを展開した以下のような形でも表現できます。

これは、テキストデータであれば、文章中の1単語ずつ順に入力していくことで以前の単語がなんであったかも含めたネットワークが作成できることを意味します。 このRNN自体は、入力して単語を入れると次の単語を出力する -> その出力を入力として与えるとさらに次の単語を出力する、という文章生成に用いることが可能です。

これに対して、入力部分をEncoder、出力部分をDecoderとして、単語単位ではなく文章単位で扱ってしまうようにしたのがseq2seqとなります。

(画像 : A Neural Conversational Modelより)

実用的には、ここで記載した内容の他にも様々な工夫がされています。 それらの詳細を知りたい方は参考文献をご覧ください。

Exampleを実行してみる

SageMakerのExamplesにあるノートブックでどのような実装が必要になるのか見ていきましょう。 今回利用するExampleは英語をドイツ語に翻訳するものになっています。

利用するExample

今回利用するExampleはノートブックの上部メニューから SageMaker Examples > Introduction to Amazon algorithms > SageMaker-Seq2Seq-Translation-English-German.ipynb とたどったものになります。

データの準備

では、ここから実際にExampleを動かしてみましょう。

まず最初はいつも通り、バケットの設定や必要なライブラリのインポートを行います。

# S3 bucket and prefix
bucket = '<your_s3_bucket_name_here>'
prefix = 'sagemaker/DEMO-seq2seq'
import boto3
import re
from sagemaker import get_execution_role

role = get_execution_role()
from time import gmtime, strftime
import time
import numpy as np
import os
import json

# For plotting attention matrix later on
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt

続いて、データのダウンロードを行います。

%%bash
wget http://data.statmt.org/wmt17/translation-task/preprocessed/de-en/corpus.tc.de.gz & \
wget http://data.statmt.org/wmt17/translation-task/preprocessed/de-en/corpus.tc.en.gz & wait
gunzip corpus.tc.de.gz & \
gunzip corpus.tc.en.gz & wait
mkdir validation
curl http://data.statmt.org/wmt17/translation-task/preprocessed/de-en/dev.tgz | tar xvzf - -C validation

データは、WMT17のものを使ってるようです。 続いて、このデータからそれぞれ先頭10000行を取り出します。

!head -n 10000 corpus.tc.en > corpus.tc.en.small
!head -n 10000 corpus.tc.de > corpus.tc.de.small

seq2seqの学習には非常に時間がかかるので、今回はExampleということで学習時間を少なくするために、データ件数を減らしているようです。 続いて、pythonプログラムを実行して以下の操作を行なっています。

  • データをrecordIO-protobuf形式に変換
  • ボキャブラリの作成

前者はデータを学習時に利用可能な形式に変換してるものになります。 後者は、文章中の単語にIDを付与してそのペアを作成する処理になっています。

%%time
%%bash
python3 create_vocab_proto.py \
        --train-source corpus.tc.en.small \
        --train-target corpus.tc.de.small \
        --val-source validation/newstest2014.tc.en \
        --val-target validation/newstest2014.tc.de

この操作によって翻訳元のデータである英語と翻訳先であるドイツ語それぞれのrecordIO-protobufファイルとボキャブラリファイルが作成されます。 create_vocab_proto.pyの実装の詳細には立ち入りませんので、気になる方は内容をご確認ください。

学習の実行

では、いよいよ学習を行います。

まずは学習に利用するコンテナの設定を行います。

region_name = boto3.Session().region_name
from sagemaker.amazon.amazon_estimator import get_image_uri
container = get_image_uri(region_name, 'seq2seq')

print('Using SageMaker Seq2Seq container: {} ({})'.format(container, region_name))

2018/08/13現在、東京リージョンを使っている場合は以下のコンテナが利用されます。

Using SageMaker Seq2Seq container: 501404015308.dkr.ecr.ap-northeast-1.amazonaws.com/seq2seq:1 (ap-northeast-1)

続いて、パラメータ設定を行なって学習処理を実行します。

job_name = 'DEMO-seq2seq-en-de-' + strftime("%Y-%m-%d-%H", gmtime())
print("Training job", job_name)

create_training_params = \
{
    "AlgorithmSpecification": {
        "TrainingImage": container,
        "TrainingInputMode": "File"
    },
    "RoleArn": role,
    "OutputDataConfig": {
        "S3OutputPath": "s3://{}/{}/".format(bucket, prefix)
    },
    "ResourceConfig": {
        # Seq2Seq does not support multiple machines. Currently, it only supports single machine, multiple GPUs
        "InstanceCount": 1,
        "InstanceType": "ml.p2.xlarge", # We suggest one of ["ml.p2.16xlarge", "ml.p2.8xlarge", "ml.p2.xlarge"]
        "VolumeSizeInGB": 50
    },
    "TrainingJobName": job_name,
    "HyperParameters": {
        # Please refer to the documentation for complete list of parameters
        "max_seq_len_source": "60",
        "max_seq_len_target": "60",
        "optimized_metric": "bleu",
        "batch_size": "64", # Please use a larger batch size (256 or 512) if using ml.p2.8xlarge or ml.p2.16xlarge
        "checkpoint_frequency_num_batches": "1000",
        "rnn_num_hidden": "512",
        "num_layers_encoder": "1",
        "num_layers_decoder": "1",
        "num_embed_source": "512",
        "num_embed_target": "512",
        "checkpoint_threshold": "3",
        "max_num_batches": "2100"
        # Training will stop after 2100 iterations/batches.
        # This is just for demo purposes. Remove the above parameter if you want a better model.
    },
    "StoppingCondition": {
        "MaxRuntimeInSeconds": 48 * 3600
    },
    "InputDataConfig": [
        {
            "ChannelName": "train",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": "s3://{}/{}/train/".format(bucket, prefix),
                    "S3DataDistributionType": "FullyReplicated"
                }
            },
        },
        {
            "ChannelName": "vocab",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": "s3://{}/{}/vocab/".format(bucket, prefix),
                    "S3DataDistributionType": "FullyReplicated"
                }
            },
        },
        {
            "ChannelName": "validation",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": "s3://{}/{}/validation/".format(bucket, prefix),
                    "S3DataDistributionType": "FullyReplicated"
                }
            },
        }
    ]
}

sagemaker_client = boto3.Session().client(service_name='sagemaker')
sagemaker_client.create_training_job(**create_training_params)

status = sagemaker_client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']
print(status)

このExampleでは、SageMaker Python SDKではなく、boto3を使っています。 SageMaker Python SDKを利用する場合は、

  • AlgorithmSpecification ~ TrainingJobNameはEstimator作成時
  • HyperParametersはset_hyperparamers関数
  • InputDataConfigはfit関数

で指定する項目となります。 なお、コメントにもある通り、seq2seqの学習は分散処理できないようです。 そのため、InstanceCountは必ず1とします。

また、ハイパーパラメータの詳細については、ドキュメントをご覧ください。

boto3を使った学習処理の実行では、すぐにプログラム側に制御を返します。 次のコードでは、現在のステータスを取得して、失敗した場合は例外を投げます。

status = sagemaker_client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']
print(status)
# if the job failed, determine why
if status == 'Failed':
    message = sagemaker_client.describe_training_job(TrainingJobName=job_name)['FailureReason']
    print('Training failed with the following error: {}'.format(message))
    raise Exception('Training job failed')

処理状況をプログラムからトラッキングしたい場合は、これを一定間隔で実行するようなコードを実装すればいいでしょう。 実際の処理状況を確認したい場合は、マネジメントコンソールからCloudWatch Logsをたどるなどしてください。

エンドポイント作成

SageMaker Python SDKでは、Estimatorのdeploy関数を使うことで以下の3つの処理を行なっていました。

  • モデルの作成
  • エンドポイント設定の作成
  • エンドポイントの作成

ここでは、それら3つの処理を個別に行なっています。

最初にembedding層に事前学習済みのモデルを利用するかの設定を行なっています。

use_pretrained_model = False

密な単語ベクトルなど、別途解説が必要な項目となるため、詳細に関しては割愛します。 詳しく知りたい方はドキュメントを参照してください。

続いて、モデルの作成を行います。

%%time

sage = boto3.client('sagemaker')

if not use_pretrained_model:
    info = sage.describe_training_job(TrainingJobName=job_name)
    model_name=job_name
    model_data = info['ModelArtifacts']['S3ModelArtifacts']

print(model_name)
print(model_data)

primary_container = {
    'Image': container,
    'ModelDataUrl': model_data
}

print(primary_container)

create_model_response = sage.create_model(
    ModelName = model_name,
    ExecutionRoleArn = role,
    PrimaryContainer = primary_container)

print(create_model_response['ModelArn'])

モデルの作成が終わったら、エンドポイント設定の作成を行います。

from time import gmtime, strftime

endpoint_config_name = 'DEMO-Seq2SeqEndpointConfig-' + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print(endpoint_config_name)
create_endpoint_config_response = sage.create_endpoint_config(
    EndpointConfigName = endpoint_config_name,
    ProductionVariants=[{
        'InstanceType':'ml.m4.xlarge',
        'InitialInstanceCount':1,
        'ModelName':model_name,
        'VariantName':'AllTraffic'}])

print("Endpoint Config Arn: " + create_endpoint_config_response['EndpointConfigArn'])

その後、このエンドポント設定を使ってエンドポイントの作成を行います。

%%time
import time

endpoint_name = 'DEMO-Seq2SeqEndpoint-' + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print(endpoint_name)
create_endpoint_response = sage.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name)
print(create_endpoint_response['EndpointArn'])

resp = sage.describe_endpoint(EndpointName=endpoint_name)
status = resp['EndpointStatus']
print("Status: " + status)

# wait until the status has changed
sage.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name)

# print the status of the endpoint
endpoint_response = sage.describe_endpoint(EndpointName=endpoint_name)
status = endpoint_response['EndpointStatus']
print('Endpoint creation ended with EndpointStatus = {}'.format(status))

if status != 'InService':
    raise Exception('Endpoint creation failed.')

作成が完了したら、この後呼び出すためにboto3のクライアントを作成しておきます。

runtime = boto3.client(service_name='runtime.sagemaker') 

結果の確認

では、早速いくつかの文章を入れて翻訳してみましょう。

sentences = ["you are so good !",
             "can you drive a car ?",
             "i want to watch a movie ."
            ]

payload = {"instances" : []}
for sent in sentences:
    payload["instances"].append({"data" : sent})

response = runtime.invoke_endpoint(EndpointName=endpoint_name, 
                                   ContentType='application/json', 
                                   Body=json.dumps(payload))

response = response["Body"].read().decode("utf-8")
response = json.loads(response)
print(response)

sentencesで指定した3つの文章を翻訳しています。 結果は以下のようになっています。

{
'predictions': [
{
'target': 'grenzüberschreitende Zusammenarbeit in Betracht .'
}, {
'target': 'wer sind Sie werden .'
}, {
'target': '1 US @-@ Dollar war die Bank auf den Kapitalmärkten .'
}
]
}

3つの文章に対して、それぞれ3つの結果が返ってきています。 しかし、私はドイツ語がわからないため、この結果がどれだけいいものなのかさっぱりわかりませんでしたww

SageMakerのseq2seqではAttentionという機構が取り入れられています。 これは、大雑把にいうと翻訳の際にどの単語同士の対応関係が重要なのかを示すようなものになっています。 続いての処理では、この翻訳前後単語の対応関係がどのようなものになっているのかを可視化しています。

まず、単一の文章を渡して翻訳を行います。

sentence = 'can you drive a car ?'

# "configuration" : {"attention_matrix":"true"}
# でAttiontion Matrixを取得している
payload = {"instances" : [{
                            "data" : sentence,
                            "configuration" : {"attention_matrix":"true"}
                          }
                         ]}

response = runtime.invoke_endpoint(EndpointName=endpoint_name, 
                                 ContentType='application/json', 
                                   Body=json.dumps(payload))

response = response["Body"].read().decode("utf-8")
response = json.loads(response)['predictions'][0]

source = sentence
target = response["target"]
attention_matrix = np.array(response["matrix"])

print("Source: %s \nTarget: %s" % (source, target))

続いて、これを可視化します。

# Define a function for plotting the attentioan matrix
def plot_matrix(attention_matrix, target, source):
    source_tokens = source.split()
    target_tokens = target.split()
    assert attention_matrix.shape[0] == len(target_tokens)
    plt.imshow(attention_matrix.transpose(), interpolation="nearest", cmap="Greys")
    plt.xlabel("target")
    plt.ylabel("source")
    plt.gca().set_xticks([i for i in range(0, len(target_tokens))])
    plt.gca().set_yticks([i for i in range(0, len(source_tokens))])
    plt.gca().set_xticklabels(target_tokens)
    plt.gca().set_yticklabels(source_tokens)
    plt.tight_layout()

plot_matrix(attention_matrix, target, source)

結果は以下のようになります。

canとsind、youとSieあたりの対応が強いようです。

この後、Exampleでは推論時のデータをProtobut形式のデータでやりとりするコードが記述されていますが、 長くなってきたので今回は割愛させていただきます。

最後に、エンドポイントの削除を行なっておきましょう。

sage.delete_endpoint(EndpointName=endpoint_name)

おわりに

今回はSageMakerの組み込みアルゴリズムのうち、Sequence to SequenceについてExampleの流れを追ってみました。 日本語とかで同じことができたらなかなか楽しそうですねw

今回のExampleではデータ件数を5852458行のうち先頭10000行と大幅に絞っていたりします。 (ドイツ語はわからないですがなんとなく)あまりちゃんと訳せていないようにも見受けられます。 それでも約30分程度処理には時間がかかりました。 seq2seqで翻訳システム同等の精度を出そうとすると、必要なデータ件数も非常に多く必要で、学習にも時間がかかります。 そういった点から、自力でやるのはなかなか難しいかもしれませんが、現代の機械翻訳システムがどのようになっているのかを知るにはちょうどいいExampleになっているかと思います。

参考文献

論文

書籍

Webサイト

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.